#coding:utf-8

import rosbag
import cv2
import os
from cv_bridge import CvBridge
from cv_bridge import CvBridgeError
import sensor_msgs.point_cloud2 as pc2
import numpy as np
import copy
import json
import multiprocessing as mp
import traceback

import logging

fileHandler = logging.FileHandler('bag_extract.log', encoding='utf-8')
consoleHandler = logging.StreamHandler()
consoleHandler.setLevel(logging.DEBUG)
logFormat = "%(asctime)s [%(levelname)s]: %(message)s"
# Attach the handlers to the root logger explicitly rather than through
# logging.basicConfig(handlers=...): that keyword only exists on Python 3.3+,
# and Python 2.7 silently ignores it, so nothing would reach bag_extract.log.
_logFormatter = logging.Formatter(logFormat)
fileHandler.setFormatter(_logFormatter)
consoleHandler.setFormatter(_logFormatter)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(fileHandler)
logger.addHandler(consoleHandler)

# Number of worker processes muli_extract() starts.
PROCESS_NUM = 1

def _pcd_header_template(fields, size, type_, count, data):
    """Assemble an unorganized (HEIGHT 1) PCD v0.7 header.

    The result still holds two ``{}`` placeholders, for WIDTH and POINTS,
    which callers fill with ``.format(n_points, n_points)``.
    """
    header_lines = [
        "VERSION 0.7",
        "FIELDS " + fields,
        "SIZE " + size,
        "TYPE " + type_,
        "COUNT " + count,
        "WIDTH {}",
        "HEIGHT 1",
        "VIEWPOINT 0 0 0 1 0 0 0",
        "POINTS {}",
        "DATA " + data,
    ]
    return "\n".join(header_lines) + "\n"


# Lidar layout: x, y, z float32; intensity, ring uint32; timestamp float64.
PCD_ASCII_TEMPLATE = _pcd_header_template(
    "x y z intensity ring timestamp", "4 4 4 4 4 8", "F F F U U F", "1 1 1 1 1 1", "ascii")
PCD_BINARY_TEMPLATE = _pcd_header_template(
    "x y z intensity ring timestamp", "4 4 4 4 4 8", "F F F U U F", "1 1 1 1 1 1", "binary")

# Radar layout: five float32 fields. The SIZE value keeps its trailing space
# and the field is spelled "dopple", both exactly as in the files produced so far.
ARS548_PCD_BINARY_TEMPLATE = _pcd_header_template(
    "x y z intensity dopple", "4 4 4 4 4 ", "F F F F F", "1 1 1 1 1", "binary")
ARBE_PCD_BINARY_TEMPLATE = _pcd_header_template(
    "x y z intensity dopple", "4 4 4 4 4 ", "F F F F F", "1 1 1 1 1", "binary")

class BagExtractor():
    """Dump the messages of one ROS bag into individual files under ``dst``.

    Images are saved as ``.png`` and point clouds as binary ``.pcd``. Pose
    messages are grouped per whole second into ``current_pose_<secs>.imu``
    JSON files. Files that already exist are never overwritten, so an
    interrupted extraction can simply be re-run.
    """

    def __init__(self, bag_file, dst):
        """
        :param bag_file: path of the bag file to read
        :param dst: output directory; it is expected to exist already (this
                    class never creates it)
        """
        self.bag_file = bag_file
        self.dst = dst
        self.bridge = CvBridge()
        # Timestamp prefixes already handed out by is_fetched(), per data kind.
        self.fps = {'img': {}, 'pcd': {}}

    def is_fetched(self, key, timestr):
        """Tell whether this frame should be extracted (one frame per timestamp prefix).

        Only the first 10 characters of ``timestr`` are compared. For a
        "%.9f"-formatted epoch time with 10 integer digits, that is the whole
        second. extract_img_pcd does not currently call this method.

        :param key: data kind, 'img' or 'pcd'
        :param timestr: frame timestamp as a string
        :return: True the first time a prefix is seen for ``key`` (the prefix is
                 then remembered); False for every later frame with that prefix
        """
        prefix = timestr[:10]
        seen = self.fps[key]
        if prefix in seen:
            return False
        seen[prefix] = 1
        return True

    def _frame_path(self, topic, stamp_sec, ext):
        """Return ``<dst>/<topic minus leading '/', '/'->'_'>_<stamp_sec as %.9f><ext>``."""
        name = topic[1:].replace('/', '_') + '_' + ("%.9f" % stamp_sec) + ext
        return os.path.join(self.dst, name)

    def _save_frame(self, path, writer, msg, label):
        """Call ``writer(path, msg)`` unless ``path`` already exists, then print ``label``."""
        if not os.path.exists(path):
            writer(path, msg)
        print(label.format(path))

    def extract_img_pcd(self, no_compress_img_topic_list, compress_img_topic_list, pcd_topic, ars548_topic, arbe_topic, imu_topic):
        """Read every message of the bag once and dump those on the given topics.

        :param no_compress_img_topic_list: topics with raw images (saved by to_no_img)
        :param compress_img_topic_list: topics with compressed images (saved by to_img)
        :param pcd_topic: collection of lidar point-cloud topics (saved by to_pcd_binary)
        :param ars548_topic: single ARS548 point-cloud topic (None to skip)
        :param arbe_topic: single Arbe point-cloud topic (None to skip)
        :param imu_topic: single pose topic; its poses are written by to_imu()
                          once the whole bag has been read
        :return: None
        """
        imus = {}
        print("prepare read ",self.bag_file,",please waiting")
        with rosbag.Bag(self.bag_file, 'r') as bag:
            for topic, msg, _t in bag.read_messages():
                # Stamps are read inside each branch only: messages on other
                # topics are not guaranteed to have a header.
                if topic in no_compress_img_topic_list:
                    path = self._frame_path(topic, msg.header.stamp.to_sec(), '.png')
                    self._save_frame(path, self.to_no_img, msg, "[camera]Extract img file {}")
                elif topic in compress_img_topic_list:
                    path = self._frame_path(topic, msg.header.stamp.to_sec(), '.png')
                    self._save_frame(path, self.to_img, msg, "[camera]Extract img file {}")
                elif topic in pcd_topic:
                    # Lidar files are named 0.05 s before the header stamp. The
                    # reason is not visible here (presumably a stamp-vs-sweep
                    # offset of the driver) — TODO confirm.
                    path = self._frame_path(topic, msg.header.stamp.to_sec() - 0.05, '.pcd')
                    self._save_frame(path, self.to_pcd_binary, msg, "[lidar]Extract pcd file {}")
                elif topic == ars548_topic:
                    path = self._frame_path(topic, msg.header.stamp.to_sec(), '.pcd')
                    self._save_frame(path, self.to_548_binary, msg, "[ars548]Extract pcd file {}")
                elif topic == arbe_topic:
                    path = self._frame_path(topic, msg.header.stamp.to_sec(), '.pcd')
                    self._save_frame(path, self.to_arbe_binary, msg, "[arbe]Extract pcd file {}")
                elif topic == imu_topic:
                    timestr = "%.9f" % msg.header.stamp.to_sec()
                    imus.setdefault(msg.header.stamp.secs, []).append({
                        'timestamp': timestr,
                        'position': [msg.pose.position.x, msg.pose.position.y, msg.pose.position.z],
                        'orientation': [msg.pose.orientation.x, msg.pose.orientation.y, msg.pose.orientation.z,
                                        msg.pose.orientation.w]
                    })
                    print("Extract Imu {}".format(timestr))
        # Write the pose files once, after every pose has been grouped. Calling
        # to_imu() per message wrote each second's file on its first pose and
        # then skipped it (file already exists), losing all later poses.
        self.to_imu(imus)

    def to_imu(self, imus):
        """Write one JSON file per second: ``<dst>/current_pose_<secs>.imu``.

        :param imus: mapping of integer seconds -> list of pose dicts
        Existing files are left untouched.
        """
        for timestamp, imuinfo in imus.items():
            imupath = os.path.join(self.dst, 'current_pose_'+str(timestamp)+'.imu')
            if os.path.exists(imupath):
                continue
            with open(imupath, 'w') as f:
                f.write(json.dumps(imuinfo, sort_keys=True, indent=4, separators=(',', ':')))

    def _write_binary_pcd(self, pcdpath, template, fmt, rows):
        """Write ``rows`` as a binary PCD: ``template`` header + one ``fmt``-packed record per row.

        All records are packed before the file is opened. A failure while
        reading or packing the points therefore no longer leaves an empty file
        behind, which extract_img_pcd's existence checks would keep forever.
        """
        import struct
        packer = struct.Struct(fmt)
        rows = list(rows)
        payload = b''.join(packer.pack(*row) for row in rows)
        header = template.format(len(rows), len(rows))
        with open(pcdpath, 'wb') as f:
            f.write(header.encode())
            f.write(payload)

    def to_pcd_binary(self, pcdpath, msg):
        """Save a lidar cloud. Point fields 0-5 are written as x, y, z (float32), intensity, ring (uint32) and timestamp (float64)."""
        points = pc2.read_points(msg)
        self._write_binary_pcd(pcdpath, PCD_BINARY_TEMPLATE, '<fffIId',
                               [pi[:6] for pi in points])

    def to_548_binary(self, pcdpath, msg):
        """Save an ARS548 cloud. Point fields 0-4 are written as five float32 values."""
        points = pc2.read_points(msg)
        self._write_binary_pcd(pcdpath, ARS548_PCD_BINARY_TEMPLATE, '<fffff',
                               [pi[:5] for pi in points])

    def to_arbe_binary(self, pcdpath, msg):
        """Save an Arbe cloud as x, y, z, intensity, dopple (float32).

        Point fields 8 and 7 supply the last two columns. The Arbe message
        layout is not visible here — TODO confirm those indices.
        """
        points = pc2.read_points(msg)
        self._write_binary_pcd(pcdpath, ARBE_PCD_BINARY_TEMPLATE, '<fffff',
                               [(pi[0], pi[1], pi[2], pi[8], pi[7]) for pi in points])

    def to_pcd_ascii(self, pcdpath, msg):
        """Save a cloud as an ASCII PCD, one space-separated ``repr`` line per point."""
        lidar = list(pc2.read_points(msg))
        header = PCD_ASCII_TEMPLATE.format(len(lidar), len(lidar))
        lines = [' '.join(repr(p) for p in pi) + '\n' for pi in lidar]
        # Text-mode file: write str, not bytes (bytes raise TypeError on Python 3).
        with open(pcdpath, 'w') as f:
            f.write(header)
            f.writelines(lines)

    def gen_pic(self, imgpath, msg):
        """Encode an image message's pixel buffer as JPEG and write it to ``imgpath``."""
        # NOTE(review): assumes msg.data holds tightly packed 3-channel uint8
        # pixels (step == width * 3) — confirm for the cameras in use.
        color = np.frombuffer(msg.data, dtype=np.uint8).reshape(msg.height, msg.width, 3)
        cv2.imencode('.jpg', color)[1].tofile(imgpath)

    def to_no_img(self, imgpath, msg):
        """Convert a raw image message to BGR and save it to ``imgpath``.

        Frames that cv_bridge cannot convert are logged and skipped.
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        except CvBridgeError as e:
            logger.error("failed to convert image %s: %s", imgpath, e)
            return
        cv2.imwrite(imgpath, cv_image)

    def to_img(self, imgpath, msg):
        """Decode a compressed image message, demosaic it and save it to ``imgpath``.

        Frames that cv_bridge cannot convert are logged and skipped.
        """
        try:
            cv_image = self.bridge.compressed_imgmsg_to_cv2(msg, desired_encoding='bgr8')
            # NOTE(review): COLOR_BayerRG2RGB expects a single-channel Bayer
            # image, but 'bgr8' is a 3-channel encoding. cv2.imwrite also
            # expects BGR, not RGB, order — confirm against real data.
            im = cv2.cvtColor(cv_image, cv2.COLOR_BayerRG2RGB)
        except CvBridgeError as e:
            logger.error("failed to convert image %s: %s", imgpath, e)
            return
        cv2.imwrite(imgpath, im)

def batch_extract(input, output):
    """Extract every ``*.bag`` file located directly in ``input`` (no recursion).

    ``<input>/<name>.bag`` is extracted into ``<output>/<name>``. The names of
    bags whose extraction raised are logged with their traceback, written one
    per line to ``<input>/errors.txt`` (rewritten on every call) and printed.
    """
    errors = []
    for bag in os.listdir(input):
        bag_file = os.path.join(input, bag)
        # Skip directories and anything that is not a bag, including the
        # errors.txt this function itself writes into ``input``.
        if os.path.isdir(bag_file) or not bag.endswith('.bag'):
            continue
        try:
            # splitext drops only the final extension: "a.b.bag" -> "a.b"
            # (rsplit('.', 2) raised ValueError on such names).
            name = os.path.splitext(bag)[0]
            dst = os.path.join(output, name)
            if not os.path.exists(dst):
                os.makedirs(dst)
            extractor = BagExtractor(bag_file, dst)
            extractor.extract_img_pcd(
                [],  # no raw (uncompressed) image topics
                ['/ca_front_left/image_raw/compressed', '/ca_front_right/image_raw/compressed',
                 '/ca_left/image_raw/compressed', '/ca_right/image_raw/compressed',
                 '/ca_back/image_raw/compressed'],
                ['/rslidar_car_points', '/rslidar_rode_points'],
                None,  # no ARS548 radar topic
                None,  # no Arbe radar topic
                '/current_pose')
        except Exception:
            logger.error(traceback.format_exc())
            errors.append(bag)
    errors_file = os.path.join(input, 'errors.txt')
    with open(errors_file, 'w') as f:
        for error in errors:
            f.write(error + '\n')
    print('***************')
    print('\n'.join(errors))

def onec_extract(bag_file, output):
    """Extract a single bag into ``output`` using this setup's fixed topic layout."""
    extractor = BagExtractor(bag_file, output)
    extractor.extract_img_pcd(
        no_compress_img_topic_list=['/ca_front_left/image_raw'],
        compress_img_topic_list=[],
        pcd_topic=['/rslidar_points'],
        ars548_topic="/ars548/cloudpoints",
        arbe_topic='/arbe/rviz/pointcloud_0',
        imu_topic='/current_pose',
    )


class MyProcessing(mp.Process):
    """Worker process that extracts bags taken from a shared task queue.

    Each task is a ``(bag_file, dst)`` tuple passed to onec_extract(). The
    worker stops once no task arrives within ``QUEUE_TIMEOUT`` seconds. The
    path of every bag whose extraction raised is put on the error queue and
    its traceback is logged.
    """

    # Seconds to wait for a task before concluding the queue is exhausted.
    QUEUE_TIMEOUT = 1.0

    def __init__(self, task_queue, error_queue):
        self.tasks = task_queue
        self.errors = error_queue
        super(MyProcessing, self).__init__()

    def run(self):
        try:
            from queue import Empty
        except ImportError:  # Python 2
            from Queue import Empty
        while True:
            # Use a timed get() instead of "empty() then get()". With several
            # workers, another one can take the last task between the two calls,
            # leaving this one blocked forever. multiprocessing's empty() can
            # also report True before the parent's items reach the pipe.
            try:
                bag_file, dst = self.tasks.get(timeout=self.QUEUE_TIMEOUT)
            except Empty:
                print(self.pid, 'process is dead')
                break
            try:
                onec_extract(bag_file, dst)
            except Exception:
                logger.error(traceback.format_exc())
                self.errors.put(bag_file)


def muli_extract(input, output):
    """Extract every ``*.bag`` under ``input`` (recursively) with PROCESS_NUM worker processes.

    ``<input>/<rel dir>/<name>.bag`` is extracted into
    ``<output>/<rel dir>/<name>``; these directories are created up front.
    Paths of bags whose extraction failed are written one per line to
    ``<input>/errors.txt`` and printed at the end.
    """
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    task_queue = mp.Queue()
    error_queue = mp.Queue()
    for path, _dirs, files in os.walk(input):
        for filename in files:
            if not filename.endswith('.bag'):
                continue
            bag_file = os.path.join(path, filename)
            print("find bag file : "+ bag_file)
            # splitext drops only the final ".bag": "a.b.bag" -> "a.b"
            # (rsplit('.', 2) raised ValueError on such names and aborted the scan).
            name = os.path.splitext(filename)[0]
            relpath = os.path.relpath(bag_file, input)
            dst = os.path.join(output, os.path.dirname(relpath), name)
            if not os.path.exists(dst):
                os.makedirs(dst)
            task_queue.put((bag_file, dst))

    process = []
    for _ in range(PROCESS_NUM):
        pro = MyProcessing(task_queue, error_queue)
        pro.start()  # start the worker; the OS schedules it
        process.append(pro)

    errors = []
    # Drain error_queue while the workers are still running. A process that
    # has put items on a queue does not terminate until they are flushed to
    # the pipe, so joining first can deadlock once enough errors pile up.
    while any(p.is_alive() for p in process):
        try:
            errors.append(error_queue.get(timeout=0.5))
        except Empty:
            pass
    for p in process:
        p.join()
    # Anything the workers flushed right before exiting.
    while True:
        try:
            errors.append(error_queue.get_nowait())
        except Empty:
            break

    errors_file = os.path.join(input, 'errors.txt')
    with open(errors_file, 'w') as f:
        try:
            for error in errors:
                f.write(error + '\n')
        except (IOError, OSError):
            print("写入错误文件失败！！！")

    print('***************')
    print('\n'.join(errors))



if __name__ == '__main__':
    # Extract every .bag found (recursively) under ./input into ./output,
    # mirroring the input directory layout. To extract a single bag, call
    # onec_extract('/input/<bag>.bag', '/output/<bag>') instead.
    input_dir, output_dir = './input', './output'
    muli_extract(input_dir, output_dir)
